/*
 * AXERA is pleased to support the open source community by making ax-samples available.
 *
 * Copyright (c) 2022, AXERA Semiconductor (Shanghai) Co., Ltd. All rights reserved.
 *
 * Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/BSD-3-Clause
 *
 * Unless required by applicable law or agreed to in writing, software distributed
 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 */

/*
 * Author: ZHEQIUSHUI
 */

#include <algorithm>
#include <array>
#include <cfloat>
#include <cstdio>
#include <cstring>
#include <numeric>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
#include "base/common.hpp"
#include "middleware/io.hpp"

#include "utilities/args.hpp"
#include "utilities/cmdline.hpp"
#include "utilities/file.hpp"
#include "utilities/timer.hpp"

#include <ax_sys_api.h>
#include <ax_engine_api.h>

#include "base/score.hpp"
#include "base/topk.hpp"

// Command-line defaults: network input height/width in pixels, and how many
// timed inference runs to perform.
constexpr int DEFAULT_IMG_H = 256;
constexpr int DEFAULT_IMG_W = 192;
constexpr int DEFAULT_LOOP_COUNT = 1;

namespace ax
{
    /**
     * @brief Find the largest value in ptr[0, len).
     *
     * Ties resolve to the lowest index. NaN entries never win a comparison, so
     * a range with no comparable value falls back to index 0. That way, any
     * non-empty range yields an index that is safe to use as a subscript.
     *
     * @param ptr     start of the values to scan
     * @param len     number of values to scan
     * @param max_val receives the value at the returned index
     *                (-FLT_MAX when the range is empty)
     * @return index of the maximum, or -1 when ptr is null or len <= 0
     */
    int find_max(float *ptr, int len, float &max_val)
    {
        max_val = -FLT_MAX;
        if (ptr == nullptr || len <= 0)
        {
            return -1;
        }

        // Start from a valid index: callers use the result as an array subscript.
        int max_idx = 0;
        for (int i = 0; i < len; ++i)
        {
            if (ptr[i] > max_val)
            {
                max_val = ptr[i];
                max_idx = i;
            }
        }
        max_val = ptr[max_idx];
        return max_idx;
    }

    /**
     * @brief Decode the pedestrian-attribute output, print it, then print the
     *        inference timing summary.
     *
     * Mirrors PaddleClas' PersonAttribute post-processing (link below). Output 0
     * is read as 26 floats laid out as:
     *   [0] hat, [1] glasses, [2..3] short/long sleeve, [4..7] upper-body
     *   patterns, [8..13] lower-body types, [14] boots, [15..17] bag kind,
     *   [18] holding object, [19..21] age group, [22] female, [23..25] direction.
     * NOTE(review): the scores are compared to the thresholds as-is, so this
     * assumes the model already applies a sigmoid. Confirm for other exports.
     *
     * @param io_info    model IO description (only output 0 is used)
     * @param io_data    IO buffers after inference (only output 0 is used)
     * @param mat        source image (currently unused)
     * @param time_costs per-run inference latencies in milliseconds
     */
    void post_process(AX_ENGINE_IO_INFO_T *io_info, AX_ENGINE_IO_T *io_data, const cv::Mat &mat, const std::vector<float> &time_costs)
    {
        // https://github.com/PaddlePaddle/PaddleClas/blob/a89269e5393ad6106277199650e3cc411ddee61c/deploy/python/postprocess.py#L192
        (void)mat;

        timer timer_postprocess;

        constexpr float threshold = 0.5f;
        constexpr float glasses_threshold = 0.3f;
        constexpr float hold_threshold = 0.6f;
        constexpr int attr_count = 26;

        auto &output = io_data->pOutputs[0];
        auto &info = io_info->pOutputs[0];
        // Refuse to index past the end of an unexpectedly small output tensor.
        if (output.pVirAddr == nullptr || info.nSize < attr_count * sizeof(float))
        {
            fprintf(stderr, "Unexpected attribute output: need %d floats, got %u bytes.\n", attr_count, (unsigned)info.nSize);
            return;
        }
        auto ptr = static_cast<float *>(output.pVirAddr);

        static const char *age_list_en[] = {"AgeLess18", "Age18-60", "AgeOver60"};
        static const char *direct_list_en[] = {"Front", "Side", "Back"};
        static const char *bag_list_en[] = {"HandBag", "ShoulderBag", "Backpack"};
        static const char *sleeve_list_en[] = {"ShortSleeve", "LongSleeve"};
        static const char *upper_list_en[] = {"UpperStride", "UpperLogo", "UpperPlaid", "UpperSplice"};
        static const char *lower_list_en[] = {"LowerStripe", "LowerPattern", "LongCoat", "Trousers", "Shorts", "Skirt&Dress"};

        static const char *age_list_chs[] = {"小于18岁", "18到60岁", "大于60岁"};
        static const char *direct_list_chs[] = {"前", "侧", "后"};
        static const char *bag_list_chs[] = {"手包", "肩袋", "背包"};
        static const char *sleeve_list_chs[] = {"短袖", "长袖"};
        static const char *upper_list_chs[] = {"条纹", "图案", "格子", "拼接"};
        static const char *lower_list_chs[] = {"下条纹", "下图案", "长外套", "长裤", "短裤", "短裙&长裙"};

        // Switch these to the *_en tables for English labels.
        const char **age_list = age_list_chs;
        const char **direct_list = direct_list_chs;
        const char **bag_list = bag_list_chs;
        const char **sleeve_list = sleeve_list_chs;
        const char **upper_list = upper_list_chs;
        const char **lower_list = lower_list_chs;
        (void)age_list_en;
        (void)direct_list_en;
        (void)bag_list_en;
        (void)sleeve_list_en;
        (void)upper_list_en;
        (void)lower_list_en;

        const char *hat = ptr[0] > threshold ? "True" : "False";
        const char *glasses = ptr[1] > glasses_threshold ? "True" : "False";
        // Index 14 is the boots score; index 4 belongs to the upper-body patterns.
        const char *shoe = ptr[14] > threshold ? "Boots" : "NoBoots";
        const char *hold_obj = ptr[18] > hold_threshold ? "True" : "False";
        const char *gender = ptr[22] > threshold ? "Female" : "Male";

        // Upper body: sleeve length is a two-way choice, patterns are multi-label.
        std::string upper_label = sleeve_list[ptr[3] > ptr[2] ? 1 : 0];
        upper_label += " # ";
        const float *upper_ptr = ptr + 4;
        for (int i = 0; i < 4; ++i)
        {
            if (upper_ptr[i] > threshold)
            {
                upper_label += upper_list[i];
                upper_label += " # ";
            }
        }

        // Lower body: multi-label, falling back to the single most likely type
        // when nothing passes the threshold (as PaddleClas does).
        std::string lower_label;
        const float *lower_ptr = ptr + 8;
        for (int i = 0; i < 6; ++i)
        {
            if (lower_ptr[i] > threshold)
            {
                lower_label += lower_list[i];
                lower_label += " # ";
            }
        }
        if (lower_label.empty())
        {
            float lower_prob_max_val;
            int lower_idx = find_max(ptr + 8, 6, lower_prob_max_val);
            lower_label += lower_list[lower_idx];
            lower_label += " # ";
        }

        float bag_prob_max_val;
        int bag_idx = find_max(ptr + 15, 3, bag_prob_max_val);
        const char *bag = bag_list[bag_idx];
        // Keep the bag only when its score strictly exceeds the threshold
        // (a NaN score also yields "NoBag").
        if (!(bag_prob_max_val > threshold))
        {
            bag = "NoBag";
        }

        float age_prob_max_val;
        int age_idx = find_max(ptr + 19, 3, age_prob_max_val);
        const char *age = age_list[age_idx];

        float direction_prob_max_val;
        int direction_idx = find_max(ptr + 23, 3, direction_prob_max_val);
        const char *direction = direct_list[direction_idx];

        fprintf(stdout, "cost time:%.2f ms \n", timer_postprocess.cost());

        fprintf(stdout, "          attr prob\n"
                        "Gender   :%s %0.2f\n"
                        "Age      :%s %0.2f\n"
                        "Glasses  :%s %0.2f\n"
                        "Hat      :%s %0.2f\n"
                        "Direction:%s %0.2f\n"
                        "HoldObj  :%s %0.2f\n"
                        "Shoe     :%s %0.2f\n"
                        "Bag      :%s %0.2f\n"
                        "Upper    :%s \n"
                        "Lower    :%s \n ",
                gender, ptr[22], age, age_prob_max_val, glasses, ptr[1], hat, ptr[0], direction, direction_prob_max_val, hold_obj, ptr[18], shoe, ptr[14], bag, bag_prob_max_val, upper_label.c_str(), lower_label.c_str());

        fprintf(stdout, "--------------------------------------\n");
        // minmax_element on an empty range returns end(); never dereference it.
        if (time_costs.empty())
        {
            return;
        }
        auto total_time = std::accumulate(time_costs.begin(), time_costs.end(), 0.f);
        auto min_max_time = std::minmax_element(time_costs.begin(), time_costs.end());
        fprintf(stdout,
                "Repeat %d times, avg time %.2f ms, max_time %.2f ms, min_time %.2f ms\n",
                (int)time_costs.size(),
                total_time / (float)time_costs.size(),
                *min_max_time.second,
                *min_max_time.first);
    }

    /**
     * @brief Load a joint model, run it on one preprocessed input and print the
     *        decoded attributes plus timing statistics.
     *
     * The caller owns AX_SYS_Init/AX_SYS_Deinit and AX_ENGINE_Deinit.
     *
     * @param model  path to the joint model file
     * @param data   preprocessed input bytes pushed into the model's input
     * @param repeat number of timed runs; must be positive
     * @param mat    source image, forwarded to post_process
     * @return true on success. NOTE(review): the SAMPLE_AX_ENGINE_DEAL_HANDLE*
     *         macros (defined outside this file) choose the value returned on
     *         their error paths. Confirm they also report failure as false.
     */
    bool run_model(const std::string &model, const std::vector<uint8_t> &data, const int &repeat, cv::Mat &mat)
    {
        // A non-positive count would wrap to a huge vector size below and leave
        // no timings to report.
        if (repeat <= 0)
        {
            fprintf(stderr, "Repeat count must be positive, got %d.\n", repeat);
            return false;
        }

        // 1. init engine
        AX_ENGINE_NPU_ATTR_T npu_attr;
        memset(&npu_attr, 0, sizeof(npu_attr));
        npu_attr.eHardMode = AX_ENGINE_VIRTUAL_NPU_DISABLE;
        auto ret = AX_ENGINE_Init(&npu_attr);
        if (0 != ret)
        {
            // Returning `ret` here would convert the non-zero error code to `true`.
            fprintf(stderr, "AX_ENGINE_Init failed, ret = 0x%x.\n", (unsigned)ret);
            return false;
        }

        // 2. load model
        std::vector<char> model_buffer;
        if (!utilities::read_file(model, model_buffer))
        {
            fprintf(stderr, "Read Run-Joint model(%s) file failed.\n", model.c_str());
            return false;
        }

        // 3. create handle
        AX_ENGINE_HANDLE handle;
        ret = AX_ENGINE_CreateHandle(&handle, model_buffer.data(), model_buffer.size());
        SAMPLE_AX_ENGINE_DEAL_HANDLE
        fprintf(stdout, "Engine creating handle is done.\n");

        // 4. create context
        ret = AX_ENGINE_CreateContext(handle);
        SAMPLE_AX_ENGINE_DEAL_HANDLE
        fprintf(stdout, "Engine creating context is done.\n");

        // 5. set io
        AX_ENGINE_IO_INFO_T *io_info;
        ret = AX_ENGINE_GetIOInfo(handle, &io_info);
        SAMPLE_AX_ENGINE_DEAL_HANDLE
        fprintf(stdout, "Engine get io info is done. \n");

        // 6. alloc io
        AX_ENGINE_IO_T io_data;
        ret = middleware::prepare_io(io_info, &io_data, std::make_pair(AX_ENGINE_ABST_DEFAULT, AX_ENGINE_ABST_CACHED));
        SAMPLE_AX_ENGINE_DEAL_HANDLE
        fprintf(stdout, "Engine alloc io is done. \n");

        // 7. insert input
        ret = middleware::push_input(data, &io_data, io_info);
        SAMPLE_AX_ENGINE_DEAL_HANDLE_IO
        fprintf(stdout, "Engine push input is done. \n");
        fprintf(stdout, "--------------------------------------\n");

        // 8. warm up (best effort: results and errors are intentionally ignored)
        for (int i = 0; i < 5; ++i)
        {
            AX_ENGINE_RunSync(handle, &io_data);
        }

        // 9. run model
        std::vector<float> time_costs(repeat, 0);
        for (int i = 0; i < repeat; ++i)
        {
            timer tick;
            ret = AX_ENGINE_RunSync(handle, &io_data);
            time_costs[i] = tick.cost();
            SAMPLE_AX_ENGINE_DEAL_HANDLE_IO
        }

        // 10. get result
        post_process(io_info, &io_data, mat, time_costs);
        fprintf(stdout, "--------------------------------------\n");

        middleware::free_io(&io_data);
        // AX_ENGINE_DestroyHandle returns 0 on success; report that as true.
        return 0 == AX_ENGINE_DestroyHandle(handle);
    }
} // namespace ax

/**
 * Demo entry point. It parses the model, image, size and repeat arguments, then
 * center-crops the image to the requested input size. It then brackets one
 * inference run with the AX system/engine init and deinit calls.
 *
 * Returns 0 after the run, or -1 on invalid arguments, an unreadable image, or
 * a failed system init.
 */
int main(int argc, char *argv[])
{
    cmdline::parser cmd;
    cmd.add<std::string>("model", 'm', "joint file(a.k.a. joint model)", true, "");
    cmd.add<std::string>("image", 'i', "image file", true, "");
    cmd.add<std::string>("size", 'g', "input_h, input_w", false, std::to_string(DEFAULT_IMG_H) + "," + std::to_string(DEFAULT_IMG_W));

    cmd.add<int>("repeat", 'r', "repeat count", false, DEFAULT_LOOP_COUNT);
    cmd.parse_check(argc, argv);

    // 0. get app args, can be removed from user's app
    auto model_file = cmd.get<std::string>("model");
    auto image_file = cmd.get<std::string>("image");

    auto model_file_flag = utilities::file_exist(model_file);
    auto image_file_flag = utilities::file_exist(image_file);

    if (!model_file_flag || !image_file_flag)
    {
        auto show_error = [](const std::string &kind, const std::string &value)
        {
            fprintf(stderr, "Input file %s(%s) is not exist, please check it.\n", kind.c_str(), value.c_str());
        };

        if (!model_file_flag)
        {
            show_error("model", model_file);
        }
        if (!image_file_flag)
        {
            show_error("image", image_file);
        }

        return -1;
    }

    auto input_size_string = cmd.get<std::string>("size");

    std::array<int, 2> input_size = {DEFAULT_IMG_H, DEFAULT_IMG_W};

    auto input_size_flag = utilities::parse_string(input_size_string, input_size);

    // Non-positive dimensions would mis-size the input buffer allocated below.
    if (!input_size_flag || input_size[0] <= 0 || input_size[1] <= 0)
    {
        auto show_error = [](const std::string &kind, const std::string &value)
        {
            fprintf(stderr, "Input %s(%s) is not allowed, please check it.\n", kind.c_str(), value.c_str());
        };

        show_error("size", input_size_string);

        return -1;
    }

    auto repeat = cmd.get<int>("repeat");
    // At least one timed run is needed to size the timing buffer and report statistics.
    if (repeat <= 0)
    {
        fprintf(stderr, "Input repeat(%d) is not allowed, please check it.\n", repeat);
        return -1;
    }

    // 1. print args
    fprintf(stdout, "--------------------------------------\n");
    fprintf(stdout, "model file : %s\n", model_file.c_str());
    fprintf(stdout, "image file : %s\n", image_file.c_str());
    fprintf(stdout, "img_h, img_w : %d %d\n", input_size[0], input_size[1]);
    fprintf(stdout, "--------------------------------------\n");

    // 2. read image & resize & transpose
    std::vector<uint8_t> image(input_size[0] * input_size[1] * 3, 0);
    cv::Mat mat = cv::imread(image_file);
    if (mat.empty())
    {
        fprintf(stderr, "Read image failed.\n");
        return -1;
    }
    cv::cvtColor(mat, mat, cv::COLOR_BGR2RGB);
    common::get_input_data_centercrop(mat, image, input_size[0], input_size[1]);

    // 3. sys_init
    if (0 != AX_SYS_Init())
    {
        fprintf(stderr, "AX_SYS_Init failed.\n");
        return -1;
    }

    // 4. -  engine model  -  can only use AX_ENGINE** inside
    {
        // AX_ENGINE_NPUReset(); // todo ??
        ax::run_model(model_file, image, repeat, mat);

        // 4.3 engine de init
        AX_ENGINE_Deinit();
        // AX_ENGINE_NPUReset();
    }
    // 4. -  engine model  -

    AX_SYS_Deinit();
    return 0;
}
